from backtrader.stores.binancestore import CCXTStore
import backtrader as bt
from backtrader.utils import logger, datehelper
from datetime import datetime, timedelta
import argparse
import json
import os


class OkexRun:
    """Command-line runner that loads a JSON config and starts a strategy.

    Construction parses the command line, loads the JSON config file located
    next to this script, resolves each setting (a command-line value wins,
    the config value is the fallback), creates the ``CCXTStore`` and applies
    the account settings on the exchange. ``run`` then assembles a
    ``bt.Cerebro`` instance and runs the configured strategy.
    """

    def __init__(self):
        """Parse ``sys.argv``, load the config and initialise store and account."""
        self.logger = logger.getLogger(__name__)
        args = self.parse_args()

        # The config file is resolved relative to this script, not the CWD.
        cfg_file = os.path.join(os.path.dirname(__file__), args.cfg)
        with open(cfg_file, "r") as f:
            self.params = json.load(f)

        account_name = args.account or self.params["account"]["name"]
        self.account = bt.AccountMgt().get_account(account_name)
        self.logger.info("%s Input Args %s", "*" * 15, "*" * 15)
        self.print_params()

        token = self.params["alert"]["token"]
        # str() accepts both the string form ("true"/"TRUE") and a JSON boolean.
        alert_enable = str(self.params["alert"]["enable"]).upper() == "TRUE"
        self.alert = bt.utils.Alert(token, alert_enable)

        self.asset = args.asset or self.params["asset"]["name"]
        self.asset_type = args.type or self.params["asset"]["type"]

        # "BASE/QUOTE" -> currency is the part before the slash,
        # asset_name is the symbol with the slash removed.
        self.currency = self.asset.split('/')[0]
        self.asset_name = self.asset.replace("/", "")

        self.time_frame = args.timeframe or self.params["strategy"]["timeframe"]
        self.resample = args.resample

        if args.lever:
            self.lever = int(args.lever)
        else:
            self.lever = self.params["account"]["leverage"]

        self.exchange = self.params["exchange"]
        self.posSide = self.params["account"]["posSide"]
        self.store = None
        self.init_exchange()
        self.init_account_setting()

    def _isolated_leverage_info(self):
        """Return the ``data`` rows of the isolated-margin leverage query for ``self.asset``."""
        response = self.store.exchange.private_get_account_leverage_info(params={
            "instId": self.asset,
            "mgnMode": "isolated"
        })
        return response["data"]

    def _send_alert(self, strategy_name, action):
        """Send a lifecycle alert for ``strategy_name`` with the given ``action`` text."""
        dtstr = datehelper.date2str(datetime.now())
        self.alert.send(self.alert.format().format(name=strategy_name,
                                                   asset=self.asset,
                                                   date=dtstr,
                                                   action=action,
                                                   detail="无"))

    def print_account_config(self):
        """Log the account configuration and, for non-spot assets, the leverage rows."""
        result = self.store.exchange.private_get_account_config()
        account_cfg = result['data'][0]
        account_fields = (
            ("Account: uid[账户ID]=%s", "uid"),
            ("Account: acctLv[账户等级]=%s", "acctLv"),
            ("Account: posMode[持仓方式]=%s", "posMode"),
            ("Account: level[用户等级]=%s", "level"),
            ("Account: levelTmp[用户临时等级]=%s", "levelTmp"),
        )
        for message, key in account_fields:
            self.logger.info(message, account_cfg[key])

        if self.asset_type.upper() == "SPOT":
            return

        for row in self._isolated_leverage_info():
            self.logger.info("Account: instId[产品ID]=%s", row["instId"])
            self.logger.info("Account: mgnMode[保证金模式]=%s", row["mgnMode"])
            self.logger.info("Account: posSide[持仓方向]=%s", row["posSide"])
            self.logger.info("Account: lever[杠杆倍数]=%s", row["lever"])

    def print_strategy_config(self):
        """Log every key/value pair of the ``strategy_params`` config section."""
        for key, value in self.params["strategy_params"].items():
            self.logger.info("Strategy: %s=%s", key, value)

    def init_exchange(self):
        """Create the ``CCXTStore`` from the account credentials and store it on ``self.store``."""
        config = {
            "apiKey": self.account.apiKey,
            "secret": self.account.secret,
            "password": self.account.password,
            "options": {
                "createMarketBuyOrderRequiresPrice": True,
            },
            "enableRateLimit": True
        }

        self.store = CCXTStore(exchange=self.exchange, currency=self.currency, config=config, retries=5)

    def init_account_setting(self):
        """Initialise the account settings on the exchange.

        1. Switch the position mode to ``long_short_mode`` when it differs.
        2. For non-spot assets, set the isolated-margin leverage of the long
           and short sides to ``self.lever`` unless the exchange already
           reports that value for the side.
        """
        exchange = self.store.exchange

        account_cfg = exchange.private_get_account_config()['data'][0]
        if account_cfg["posMode"] != "long_short_mode":
            exchange.private_post_account_set_position_mode(params={
                "posMode": "long_short_mode"
            })

        if self.asset_type.upper() == "SPOT":
            return

        # Compare as floats on both sides so a leverage given as a string in
        # the config does not always look different from the exchange value.
        target_lever = float(self.lever)

        # A side that the exchange does not report at all is treated as unset.
        needs_update = {"long": True, "short": True}
        for row in self._isolated_leverage_info():
            side = row["posSide"]
            if side in needs_update:
                needs_update[side] = float(row["lever"]) != target_lever

        for side in ("long", "short"):
            if needs_update[side]:
                exchange.private_post_account_set_leverage(params={
                    "instId": self.asset,
                    "lever": str(self.lever),
                    "mgnMode": "isolated",
                    "posSide": side
                })

    def print_params(self):
        """Log the whole loaded config dictionary."""
        self.logger.info(self.params)

    def run(self, pargs=None):
        """Build the Cerebro engine and run the configured strategy.

        Args:
            pargs: optional argument list forwarded to ``parse_args``;
                ``None`` makes argparse read ``sys.argv``.

        Raises:
            Exception: when the strategy module exposes no name ending in
                ``Strategy``; any exception raised while running is re-raised
                after a crash alert has been sent.
        """
        args = self.parse_args(pargs)
        timeframe, compression, total_minutes = datehelper.parse_timeframe(self.time_frame)

        # Request one bar more than the lookback window as history.
        lookback = self.params["strategy"]["lookback"]
        ohlcv_limit = lookback + 1
        history_minutes = ohlcv_limit * total_minutes
        hist_start_date = datetime.utcnow() - timedelta(minutes=history_minutes)

        self.logger.info("%s Program Args %s", "*" * 15, "*" * 15)
        self.logger.info("Start: name=%s", self.asset_name)
        self.logger.info("Start: exchange=%s", self.exchange)
        self.logger.info("Start: asset=%s", self.asset)
        self.logger.info("Start: currency=%s", self.currency)
        self.logger.info("Start: lever=%s", self.lever)
        self.logger.info("Start: timeframe=%s", timeframe)
        self.logger.info("Start: compression=%s", compression)
        self.logger.info("Start: history_minutes=%s", history_minutes)
        self.logger.info("Start: hist_start_date=%s", hist_start_date)
        self.logger.info("Start: limit=%s", ohlcv_limit)
        self.logger.info("%s", "*" * 40)

        self.logger.info("%s Account Args %s", "*" * 15, "*" * 15)
        self.print_account_config()

        self.logger.info("%s Strategy Args %s", "*" * 15, "*" * 15)
        self.print_strategy_config()

        cerebro = bt.Cerebro(quicknotify=True)

        strategy_code = self.params["strategy"]["code"]
        strategy_name = args.name or self.params["strategy"]["name"]

        # The strategy class is the first module-level name ending in "Strategy".
        strat = __import__(strategy_code)
        strat_names = [e for e in strat.__dict__ if e.endswith("Strategy")]
        if not strat_names:
            raise Exception("策略名称必须以Strategy结尾.")
        strat_clz = strat_names[0]

        # Work on a copy so the loaded config is not modified by the extra keys.
        init_params = dict(self.params["strategy_params"])
        init_params["name"] = strategy_name
        init_params["token"] = self.alert.token
        init_params["enable"] = self.alert.enable
        init_params["leverage"] = self.lever

        self.logger.info("Load Strategy: class=%s", strat_clz)
        self.logger.info("Load Strategy: name=%s", strategy_name)
        self.logger.info("Load Strategy: param=%s", init_params)

        # Add the strategy
        cerebro.addstrategy(strat.__dict__[strat_clz], **init_params)

        broker_mapping = {
            'order_types': {
                bt.Order.Market: 'market',
                bt.Order.Limit: 'limit',
                bt.Order.Stop: 'stop-loss',  # stop-loss for kraken, stop for bitmex
                bt.Order.StopLimit: 'stop limit'
            },
            'mappings': {
                'closed_order': {
                    'key': 'status',
                    'value': 'closed'
                },
                'canceled_order': {
                    'key': 'status',
                    'value': 'canceled'
                }
            }
        }

        broker = self.store.getbroker(broker_mapping=broker_mapping)
        cerebro.setbroker(broker)

        data = self.store.getdata(dataname=self.asset, name=self.asset_name,
                                  timeframe=timeframe, fromdate=hist_start_date,
                                  compression=compression, ohlcv_limit=ohlcv_limit, drop_newest=True)

        # Add the feed
        cerebro.adddata(data, name=self.asset_name)

        if self.resample:
            # bar2edge=False is required here, otherwise an error is raised.
            _resample, _compression, _ = datehelper.parse_timeframe(self.resample)
            cerebro.resampledata(dataname=data, name=self.asset_name + f"_{self.resample}",
                                 timeframe=_resample, compression=_compression, bar2edge=False)

        try:
            # Run the strategy
            self._send_alert(strategy_name, "程序启动")
            cerebro.run(runonce=False)
        except Exception:
            self._send_alert(strategy_name, "程序崩溃")
            # Bare raise keeps the original traceback intact.
            raise

    def parse_args(self, pargs=None):
        """Parse the command line.

        Args:
            pargs: optional list of argument strings; ``None`` makes argparse
                read ``sys.argv``.

        Returns:
            The ``argparse.Namespace`` with ``cfg``, ``name``, ``account``,
            ``asset``, ``type``, ``lever``, ``timeframe`` and ``resample``.
            Every option except ``--cfg`` is optional and defaults to ``None``.
        """
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                         description='Mercury strategy')
        parser.add_argument('--cfg', required=True, default='cfg')
        # Optional: when omitted, run() falls back to the name in the config.
        parser.add_argument('--name', required=False, default=None)
        parser.add_argument('--account', required=False, default=None)
        parser.add_argument('--asset', required=False, default=None)
        parser.add_argument('--type', required=False, default=None)
        parser.add_argument('--lever', required=False, default=None)
        parser.add_argument('--timeframe', required=False, default=None)
        parser.add_argument('--resample', required=False, default=None)

        return parser.parse_args(pargs)


if __name__ == "__main__":
    # Script entry point: construct the runner, then start the strategy.
    runner = OkexRun()
    runner.run()
